#pragma once

#include <iostream>
#include <utility>

// Node colour used by the red-black tree.
// Kept as an unscoped enum so RED / BLACK remain usable without qualification.
enum Color
{
	RED = 0,    // colour given to every freshly constructed node
	BLACK = 1,
};

// A single red-black tree node holding one key/value pair.
//
// K - key type, stored in _kv.first
// V - mapped type, stored in _kv.second
//
// The node does not own its children or its parent; linking and freeing
// nodes is the responsibility of whoever builds the tree.
template<class K, class V>
struct RBTreeNode
{
	std::pair<K, V> _kv;                    // the stored key/value pair
	RBTreeNode<K, V>* _left = nullptr;      // left child, or nullptr
	RBTreeNode<K, V>* _right = nullptr;     // right child, or nullptr
	RBTreeNode<K, V>* _parent = nullptr;    // parent node, or nullptr when unlinked
	Color _col = RED;                       // a new node always starts out red

	// Builds an unlinked red node that holds a copy of kv.
	RBTreeNode(const std::pair<K, V>& kv)
		: _kv(kv)
	{}
};

template<class K, class V>
class RBTree
{
	typedef RBTreeNode<K, V> Node;
public:
	bool Insert(const pair<K, V>& kv)
	{
		if (_root == nullptr)
		{
			_root = new Node(kv);
			_root->_col = BLACK;
			return true;
		}
		else
		{
			Node* cur = _root;
			Node* parent = nullptr;

			while (cur)
			{
				if (cur->_kv.first < kv.first)
				{
					parent = cur;
					cur = cur->_right;
				}
				else if (cur->_kv.first > kv.first)
				{
					parent = cur;
					cur = cur->_left;
				}
				else
				{
					return false;
				}
			}

			cur = new Node(kv);
			cur->_col = RED;

			if (parent->_kv.first < kv.first)
			{
				parent->_right = cur;
				cur->_parent = parent;
			}
			else
			{
				parent->_left = cur;
				cur->_parent = parent;
			}


			while (parent && parent->_col == RED)
			{
				Node* grandfather = parent->_parent;
				if (parent == grandfather->_left)
				{
					Node* uncle = grandfather->_right;

					// 情况一  uncle存在且为红
					if (uncle && uncle->_col == RED)
					{
						parent->_col = uncle->_col = BLACK;
						grandfather->_col = RED;

						cur = grandfather;
						parent = cur->_parent;
					}
					else
					{
						// 情况二
						if (cur == parent->_left)
						{
							RotateR(grandfather);

							parent->_col = BLACK;
							grandfather->_col = RED;
						}
						// 情况三
						else
						{
							RotateL(parent);
							RotateR(grandfather);

							cur->_col = BLACK;
							grandfather->_col = RED;
						}

						break;
					}
				}
				else
				{
					Node* uncle = grandfather->_left;
					if (uncle && uncle->_col == RED)
					{
						uncle->_col = parent->_col = BLACK;
						grandfather->_col = RED;

						cur = grandfather;
						parent = cur->_parent;
					}
					else
					{
						// g
						//     p
						// c
						if (cur == parent->_left)
						{
							RotateR(parent);
							RotateL(grandfather);

							cur->_col = BLACK;
							grandfather->_col = RED;
						}
						// g
						//    p
						//      c
						else
						{
							RotateL(grandfather);

							parent->_col = BLACK;
							grandfather->_col = RED;
						}

						break;
					}
				}
			}

			_root->_col = BLACK;

			return true;
		}
	}
	void RotateL(Node* parent)
	{
		Node* SubR = parent->_right;
		Node* SubRL = SubR->_left;

		parent->_right = SubRL;
		if (SubRL)
			SubRL->_parent = parent;

		Node* ppNode = parent->_parent;

		SubR->_left = parent;
		parent->_parent = SubR;

		if (ppNode == nullptr)
		{
			_root = SubR;
			SubR->_parent = nullptr;
		}
		else
		{
			if (parent == ppNode->_left)
			{
				ppNode->_left = SubR;
				SubR->_parent = ppNode;
			}
			else
			{
				ppNode->_right = SubR;
				SubR->_parent = ppNode;
			}
		}
	}

	void RotateR(Node* parent)
	{
		Node* SubL = parent->_left;
		Node* SubLR = SubL->_right;

		parent->_left = SubLR;
		if (SubLR)
			SubLR->_parent = parent;

		Node* ppNode = parent->_parent;
		SubL->_right = parent;
		parent->_parent = SubL;

		if (ppNode == nullptr)
		{
			_root = SubL;
			SubL->_parent = nullptr;
		}
		else
		{
			if (parent == ppNode->_left)
			{
				ppNode->_left = SubL;
			}
			else
			{
				ppNode->_right = SubL;
			}

			SubL->_parent = ppNode;
		}
	}

	void InOrder()
	{
		_InOrder(_root);
	}

	void _InOrder(Node* root)
	{
		if (root == nullptr)
			return;

		_InOrder(root->_left);
		cout << root->_kv.first << ":" << root->_kv.second << endl;
		_InOrder(root->_right);
	}

	bool check(Node* root, int blackNum, int ref)
	{
		if (root == nullptr)
		{
			if (blackNum != ref)
			{
				cout << "违反规则:本条路径的黑色节点的数量跟最左路径不相等" << endl;
				return false;
			}

			return true;
		}

		if (root->_col == RED && root->_parent->_col == RED)
		{
			cout << "违反规则：出现连续红色节点" << endl;
			return false;
		}

		if (root->_col == BLACK)
			blackNum++;

		return check(root->_left, blackNum, ref)
			&& check(root->_right, blackNum, ref);
	}


	bool IsBalance()
	{
		if (_root == nullptr)
			return true;

		if (_root->_col != BLACK)
			return false;

		int ref = 0;// 统计黑节点的个数
		Node* left = _root;
		while (left)
		{

			if (left->_col == BLACK)
				ref++;

			left = left->_left;
		}

		return check(_root, 0, ref);
	}
private:
	Node* _root = nullptr;
};
